package io.jpom.plugin.netty;

import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.core.util.URLUtil;
import cn.hutool.extra.ssh.ChannelType;
import cn.hutool.extra.ssh.JschUtil;
import cn.jiangzeyin.common.DefaultSystemLog;
import cn.jiangzeyin.common.JsonMessage;
import cn.jiangzeyin.common.spring.SpringUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.SftpATTRS;
import com.jcraft.jsch.SftpException;
import io.jpom.model.data.SshModel;
import io.jpom.model.data.UserModel;
import io.jpom.service.node.ssh.SshService;
import io.jpom.service.user.UserService;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.multipart.Attribute;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import io.netty.handler.codec.http.multipart.InterfaceHttpData;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.stream.ChunkedStream;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.http.MediaType;

import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

/**
 * 下载handler
 *
 * @author myzf
 * @date 2019/8/11
 */
public class FileServerHandler extends SimpleChannelInboundHandler<Object> {
    /**
     * 所有活动channel
     */
    private static ChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    private static final TimedCache<String, JSONObject> DOWNLOAD = new TimedCache<>(TimeUnit.MINUTES.toMillis(1));

    private static final Map<String, Set<String>> USER_DOWNLOAD = new HashMap<>();

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            FullHttpRequest request = (FullHttpRequest) msg;
            //检测解码情况
            if (!request.decoderResult().isSuccess()) {
                sendError(ctx, BAD_REQUEST);
                return;
            }
            this.handlerReq(ctx, request);
        } else if (msg instanceof WebSocketFrame) {
            TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) msg;
            this.handlerSocket(ctx, textWebSocketFrame);
        }
    }

    private void handlerReq(ChannelHandlerContext ctx, FullHttpRequest request) throws IOException {
        Map<String, String> parse = parse(request);
        //获取请求参数 共下面页面单个下载用
        String reqId = parse.get("reqId");
        if (StrUtil.isEmpty(reqId)) {
            sendError(ctx, NOT_FOUND);
            return;
        }
        JSONObject jsonObject = DOWNLOAD.get(reqId);
        if (jsonObject == null) {
            sendError(ctx, NOT_FOUND);
            return;
        }
        SshModel sshModel = jsonObject.getObject("ssh", SshModel.class);
        String allPath = jsonObject.getString("allPath");
        //
        ChannelId channelId = jsonObject.getObject("channelId", ChannelId.class);
        Channel channel = channels.find(channelId);
        if (channel == null) {
            sendError(ctx, NOT_FOUND);
            return;
        }
        UserModel user = jsonObject.getObject("user", UserModel.class);
        try {
            Session session = JschUtil.openSession(sshModel.getHost(), sshModel.getPort(), sshModel.getUser(), sshModel.getPassword());
            ChannelSftp channelSftp = (ChannelSftp) JschUtil.openChannel(session, ChannelType.SFTP);
            SftpATTRS attr = channelSftp.stat(allPath);
            long fileSize = attr.getSize();
            PipedInputStream pipedInputStream = new PipedInputStream();
            PipedOutputStream pipedOutputStream = new PipedOutputStream(pipedInputStream);
            ThreadUtil.execute(() -> {
                try {
                    channelSftp.get(allPath, pipedOutputStream);
                } catch (SftpException e) {
                    DefaultSystemLog.getLog().error("下载异常", e);
                }
                IoUtil.close(pipedOutputStream);
            });
            // 缓存
            Set<String> stringSet = USER_DOWNLOAD.computeIfAbsent(user.getId(), (Function<String, HashSet<String>>) s -> new HashSet<>());
            stringSet.add(allPath);
            //
            HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
            HttpUtil.setContentLength(response, fileSize);

            response.headers().set(CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE);
            // 设定默认文件输出名
            String fileName = URLUtil.encode(FileUtil.getName(allPath));
            response.headers().add("Content-disposition", "attachment; filename=" + fileName);

            ctx.write(response);
            ChannelFuture sendFileFuture = ctx.write(new HttpChunkedInput(new ChunkedStream(pipedInputStream)), ctx.newProgressivePromise());
            Progress progress = new Progress(channel, reqId, session, channelSftp, fileSize);
            sendFileFuture.addListener(progress);

            ChannelFuture lastContentFuture = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);

            if (!HttpUtil.isKeepAlive(request)) {
                lastContentFuture.addListener(ChannelFutureListener.CLOSE);
            }
        } catch (Exception e) {
            DefaultSystemLog.getLog().error("下载失败", e);
            sendError(ctx, INTERNAL_SERVER_ERROR);
        }
    }

    private void handlerSocket(ChannelHandlerContext ctx, TextWebSocketFrame textWebSocketFrame) throws IOException {
        // 获取客户端传输过来的消息
        String content = textWebSocketFrame.text();
        Channel channel = ctx.channel();
        if (StrUtil.isEmpty(content)) {
            channel.close();
            return;
        }
        // 1. 获取客户端发来的消息
        JSONObject jsonObject = JSON.parseObject(content);
        String event = jsonObject.getString("event");
        if ("download".equals(event)) {
            UserService userService = SpringUtil.getBean(UserService.class);
            //
            String userId = jsonObject.getString("userId");
            UserModel userModel = userService.checkUser(userId);
            if (userModel == null) {
                sendSocketMsg(ctx, new JsonMessage(800, "用户信息获取失败"));
                return;
            }
            //
            String id = jsonObject.getString("id");
            SshService sshService = SpringUtil.getBean(SshService.class);
            SshModel sshModel = sshService.getItem(id);
            if (sshModel == null) {
                sendSocketMsg(ctx, new JsonMessage(405, "没有对应的ssh信息"));
                return;
            }
            String path = jsonObject.getString("path");
            List<String> fileDirs = sshModel.getFileDirs();
            //
            if (StrUtil.isEmpty(path) || !fileDirs.contains(path)) {
                sendSocketMsg(ctx, new JsonMessage(405, "非法路径"));
                return;
            }
            String name = jsonObject.getString("name");
            if (StrUtil.isEmpty(name)) {
                sendSocketMsg(ctx, new JsonMessage(405, "没有name参数"));
                return;
            }
            // 防止越级   ../
            try {
                FileUtil.file(path, URLUtil.encode(name));
            } catch (Exception e) {
                DefaultSystemLog.getLog().error("非法路径", e);
                sendSocketMsg(ctx, new JsonMessage(405, "非法请求下载"));
                return;
            }
            String allPath = FileUtil.normalize(path + "/" + name);
            Set<String> strings = USER_DOWNLOAD.get(userModel.getId());
            if (strings != null && strings.contains(allPath)) {
                sendSocketMsg(ctx, new JsonMessage(405, "此文件正在下载中"));
                return;
            }
            // 下载请求监测成功
            String reqId = IdUtil.fastSimpleUUID();
            JSONObject data = new JSONObject();
            data.put("ssh", sshModel);
            data.put("user", userModel);
            data.put("channelId", channel.id());
            data.put("allPath", allPath);
            DOWNLOAD.put(reqId, data);
            // 添加
            channels.add(channel);
            // 响应给客户端
            JSONObject result = new JSONObject();
            result.put("reqId", reqId);
            String name1 = FileUtil.getName(name);
            result.put("name", name1);
            result.put("event", "download");
            sendSocketMsg(ctx, new JsonMessage<>(200, "", result));
        }
    }

    private void sendSocketMsg(ChannelHandlerContext channelHandlerContext, JsonMessage jsonMessage) {
        channelHandlerContext.channel().writeAndFlush(new TextWebSocketFrame(jsonMessage.toString()));
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {

    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        channels.remove(ctx.channel());
    }


    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        if (ctx.channel().isActive()) {
            sendError(ctx, INTERNAL_SERVER_ERROR);
        }
        ctx.channel().close();
        channels.remove(ctx.channel());
    }


    private static void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) {
        FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, status, Unpooled.copiedBuffer("Failure: " + status + "\r\n", CharsetUtil.UTF_8));
        response.headers().set(CONTENT_TYPE, "text/plain; charset=UTF-8");
        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * 解析netty
     * 请求参数
     *
     * @param request req
     * @return 包含所有请求参数的键值对, 如果没有参数, 则返回空Map
     * @throws IOException io
     */
    private Map<String, String> parse(FullHttpRequest request) throws IOException {
        HttpMethod method = request.method();
        Map<String, String> parmMap = new HashMap<>();
        if (HttpMethod.GET == method) {
            // 是GET请求
            QueryStringDecoder decoder = new QueryStringDecoder(request.uri());
            decoder.parameters().forEach((key, value) -> {
                // entry.getValue()是一个List, 只取第一个元素
                parmMap.put(key, value.get(0));
            });
        } else if (HttpMethod.POST == method) {
            // 是POST请求
            HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(request);
            decoder.offer(request);
            List<InterfaceHttpData> parmList = decoder.getBodyHttpDatas();
            for (InterfaceHttpData parm : parmList) {
                Attribute data = (Attribute) parm;
                parmMap.put(data.getName(), data.getValue());
            }
        } else {
        }
        return parmMap;
    }
}